Add Full TE Spec support for Megatron Pruning DynamicModules + MoE bug fixes #1024
kevalmorabia97 merged 9 commits into main from
Conversation
📝 Walkthrough

Adds Transformer Engine (TE) support across Megatron NAS and Minitron pruning: consolidates TE dynamic modules into the Megatron plugin, extracts fused TELayerNorm activations per rank, switches spec factories to TE factories, updates the pruning searcher to local per-rank activations and candidate caching, and adapts examples and tests to TE paths.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Searcher as MCoreMinitronSearcher
    participant Model as Megatron/TE Model
    participant Hook as ActivationHook
    participant Checkpoint as Per-rank Checkpoint
    participant Pruner as PruneRoutine
    rect rgba(100,149,237,0.5)
        Searcher->>Model: register ActivationHook on TELayerNormColumnParallelLinear
        Model-->>Hook: fused layernorm+linear forward outputs
        Hook->>Searcher: collect per-module activations (local_activations)
    end
    rect rgba(60,179,113,0.5)
        Searcher->>Checkpoint: set_local_activations_and_layer_scores(local_activations, layer_scores)
        Checkpoint-->>Searcher: saved per-rank activations
        Checkpoint->>Searcher: load local_activations for run_search
    end
    rect rgba(255,140,0,0.5)
        Searcher->>Pruner: invoke _prune using collected scores
        Pruner->>Model: apply pruning masks (no early break)
        Pruner->>Model: reinitialize token dispatcher if needed
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (1 warning)
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@           Coverage Diff            @@
##             main    #1024    +/-  ##
========================================
- Coverage   70.30%   70.29%   -0.02%
========================================
  Files         227      227
  Lines       25857    25857
========================================
- Hits        18179    18176       -3
- Misses       7678     7681       +3
```

☔ View full report in Codecov by Sentry.
8f42e0f to cff7137
Actionable comments posted: 1
🧹 Nitpick comments (1)
modelopt/torch/prune/plugins/mcore_minitron.py (1)
918-935: Add checkpoint key validation before restoring `_activations`.

`m._activations = local_activations[n]` will raise a raw `KeyError` on checkpoint/model drift. A pre-check with a clear error message will make resume failures diagnosable.

Proposed hardening

```diff
 def set_local_activations_and_layer_scores(
 @@
     print_rank_0("Loading activations and scores from per-rank checkpoint...")
     for layer in self.model.decoder.layers:
         layer._scores = layer_scores[layer.layer_number]
+    expected_keys = [
+        n for n, m in self.model.named_modules() if hasattr(m, "_activations")
+    ]
+    missing = [k for k in expected_keys if k not in local_activations]
+    if missing:
+        raise KeyError(
+            f"Missing activation entries for modules: {missing[:8]}"
+            + (" ..." if len(missing) > 8 else "")
+        )
     for n, m in self.model.named_modules():
         if hasattr(m, "_activations"):
             m._activations = local_activations[n]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/prune/plugins/mcore_minitron.py` around lines 918 - 935, The current set_local_activations_and_layer_scores method assigns m._activations = local_activations[n] without validating the key, which will raise a raw KeyError on checkpoint/model drift; update set_local_activations_and_layer_scores to check if n is in local_activations before assignment (iterate over self.model.named_modules()), and if missing either raise a clear ValueError that includes the module name n and a summary of available keys (e.g., list(local_activations.keys())) or log a descriptive warning and skip restoring that module, so failures are diagnosable; reference the method name set_local_activations_and_layer_scores and attributes _activations, local_activations, and model.named_modules() when making the change.
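The suggested pre-check can be sketched standalone with plain dictionaries; `restore_activations` and the module names below are hypothetical stand-ins for illustration, not the actual ModelOpt implementation:

```python
def restore_activations(expected_modules, local_activations):
    """Validate a per-rank checkpoint dict before restoring any state.

    expected_modules: names of modules that carry an `_activations` attribute.
    local_activations: checkpoint payload keyed by module name.
    """
    # Fail fast with a diagnosable message on checkpoint/model drift,
    # instead of a raw KeyError mid-assignment.
    missing = [n for n in expected_modules if n not in local_activations]
    if missing:
        preview = ", ".join(missing[:8]) + (" ..." if len(missing) > 8 else "")
        raise KeyError(f"Missing activation entries for modules: {preview}")
    return {n: local_activations[n] for n in expected_modules}
```

Validating up front means either all modules are restored or none are, so a drifted checkpoint cannot leave the model half-populated.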
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Around line 870-874: You’re overwriting TE modules’ return_layernorm_output
state unconditionally; instead capture each TELayerNormColumnParallelLinear’s
original return_layernorm_output before changing it (e.g., store in a dict keyed
by id(module) or attach a private attribute like _orig_return_layernorm_output
on the module) when you set it to True, and in the cleanup loop restore each
module’s original value rather than forcing False; apply the same
save-and-restore pattern for the other similar block referenced (lines
~999-1006) so original behavior is preserved after pruning/search.
---
Nitpick comments:
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Around line 918-935: The current set_local_activations_and_layer_scores method
assigns m._activations = local_activations[n] without validating the key, which
will raise a raw KeyError on checkpoint/model drift; update
set_local_activations_and_layer_scores to check if n is in local_activations
before assignment (iterate over self.model.named_modules()), and if missing
either raise a clear ValueError that includes the module name n and a summary of
available keys (e.g., list(local_activations.keys())) or log a descriptive
warning and skip restoring that module, so failures are diagnosable; reference
the method name set_local_activations_and_layer_scores and attributes
_activations, local_activations, and model.named_modules() when making the
change.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 7ec3047d-20d8-40ab-b3d9-6a92aa5ec6c0
📒 Files selected for processing (2)
- modelopt/torch/prune/plugins/mcore_minitron.py
- tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
ChenhanYu left a comment
Review: Add Full TE Spec support for Megatron Pruning DynamicModules + MoE bug fixes
Summary
This PR migrates Megatron pruning from custom ModelOpt spec to standard Transformer Engine (TE) spec, enabling retirement of the ModelOpt spec. Key changes:
- New TE-specific DynamicModule classes (`_DynamicTEParallelLinear`, `_DynamicTEColumnParallelLinear`, etc.) that use TE's `in_features`/`out_features` naming
- Fused LayerNorm handling via `return_layernorm_output=True` patching on `TELayerNormColumnParallelLinear`
- Activation collection simplified from allgather across PP ranks to per-rank local storage
- Two MoE bug fixes: (1) token dispatcher reinit was only applied to the first MoE layer (`break` removal), (2) `pre_mlp_layernorm` hooks were missing for MoE layers
Overall well-structured with thorough testing.
Attention Required
- TE is now a hard requirement for pruning. The `HAS_TE` conditional was removed and the `DotProductAttention` fallback dropped; `assert isinstance(self.core_attention, TEDotProductAttention)` will crash without TE. This is an intentional design shift (ModelOpt spec retirement), but worth confirming no internal users still run without TE.
- Activation storage format changed (checkpoint-breaking). `activations_per_rank` (list) → `local_activations` (dict). Existing pruning checkpoints saved with the old format will fail to load. The PR description doesn't mention a migration path.
```python
bf16=bf16,
# MoE-specific parameters
moe_router_dtype=None,
moe_grouped_gemm=moe_grouped_gemm,
```
Changed from "fp32" to None. Is this intentional or an unrelated cleanup? Running MoE routers without explicit fp32 dtype can affect numerical stability.
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/_test_utils/torch/megatron/models.py (1)
308-309: Reject unknown `transformer_impl` values in the Mamba helper.

Unlike `get_mcore_gpt_model`, this helper treats every non-`"transformer_engine"` value as `"modelopt"`. A typo or a caller passing `"local"` will silently stop exercising the intended path.

♻️ Proposed fix

```diff
 def get_mcore_mamba_hybrid_model(
 @@
     sequence_parallel: bool = False,
     transformer_impl: str = "modelopt",
 @@
 ) -> MambaModel:
 @@
     """
     assert HAS_MAMBA, "Mamba not installed"
+    assert transformer_impl in ["modelopt", "transformer_engine"]
 @@
     if transformer_impl == "transformer_engine":
         mamba_spec = get_te_mamba_stack_spec(moe_grouped_gemm=moe_grouped_gemm)
-    else:
+    else:  # transformer_impl == "modelopt"
         mamba_spec = get_mamba_stack_modelopt_spec(remap_te_layernorm=True)
```

Also applies to: 390-393
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/_test_utils/torch/megatron/models.py` around lines 308 - 309, The Mamba helper currently treats any transformer_impl value other than "transformer_engine" as "modelopt", allowing silent typos; update the helper(s) that read the transformer_impl parameter (around the transformer_impl = "modelopt" default and the code at the other occurrence) to perform explicit validation: accept only the supported strings ("modelopt" and "transformer_engine") and raise a ValueError with a clear message if an unknown value is passed, rather than silently defaulting—adjust both occurrences (lines near transformer_impl and the block at 390-393) to enforce this check.
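The explicit-validation pattern the comment asks for boils down to a guard like the following; `resolve_spec` and the returned labels are illustrative, not the test helper's real signature:

```python
SUPPORTED_IMPLS = ("modelopt", "transformer_engine")

def resolve_spec(transformer_impl: str) -> str:
    # Reject unknown values up front instead of silently defaulting
    # to the "modelopt" path on a typo like "local".
    if transformer_impl not in SUPPORTED_IMPLS:
        raise ValueError(
            f"Unknown transformer_impl={transformer_impl!r}; "
            f"expected one of {SUPPORTED_IMPLS}"
        )
    return "te_spec" if transformer_impl == "transformer_engine" else "modelopt_spec"
```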
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py`:
- Around line 74-75: The tests set transformer_impl="transformer_engine" but
lack a guard, so add a Transformer Engine availability skip: either call
pytest.importorskip("megatron.core.extensions.transformer_engine") at module
scope in the test module (alongside the existing skip_if_no_mamba() call) or
implement a helper skip_if_no_transformer_engine() in
_test_utils/import_helper.py and invoke it next to skip_if_no_mamba(); ensure
the check runs before any model construction that uses transformer_impl to avoid
runtime failures.
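A sketch of such a helper, assuming availability can be probed with `importlib`; `skip_if_no_transformer_engine` takes the skip callback (e.g. `pytest.skip`) as a parameter so the sketch stays test-framework-agnostic:

```python
import importlib.util

def module_available(name: str) -> bool:
    # find_spec raises ModuleNotFoundError when a parent package is missing,
    # so treat that the same as "not installed".
    try:
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        return False

def skip_if_no_transformer_engine(skip):
    """Call `skip(reason)` (e.g. pytest.skip) when the TE extension is absent."""
    if not module_available("megatron.core.extensions.transformer_engine"):
        skip("Transformer Engine extension for Megatron-Core is not available")
```

In the real test module this would run at module scope, before any model construction, alongside the existing `skip_if_no_mamba()`.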
---
Nitpick comments:
In `@tests/_test_utils/torch/megatron/models.py`:
- Around line 308-309: The Mamba helper currently treats any transformer_impl
value other than "transformer_engine" as "modelopt", allowing silent typos;
update the helper(s) that read the transformer_impl parameter (around the
transformer_impl = "modelopt" default and the code at the other occurrence) to
perform explicit validation: accept only the supported strings ("modelopt" and
"transformer_engine") and raise a ValueError with a clear message if an unknown
value is passed, rather than silently defaulting—adjust both occurrences (lines
near transformer_impl and the block at 390-393) to enforce this check.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: d8683652-f5b5-467a-9641-54568a74e6f6
📒 Files selected for processing (6)
CHANGELOG.rstexamples/megatron_bridge/prune_minitron.pymodelopt/torch/nas/plugins/megatron.pymodelopt/torch/prune/plugins/mcore_minitron.pytests/_test_utils/torch/megatron/models.pytests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
🚧 Files skipped from review as they are similar to previous changes (4)
- examples/megatron_bridge/prune_minitron.py
- CHANGELOG.rst
- modelopt/torch/nas/plugins/megatron.py
- modelopt/torch/prune/plugins/mcore_minitron.py
tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
cf71476 to 067e80a
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/prune/plugins/mcore_minitron.py (1)
477-513: ⚠️ Potential issue | 🟡 Minor: Make candidate validation exception-safe.

This block temporarily disables `enable_expert_bias` and mutates the model into each candidate subnet, but restoration only happens on the happy path. If `_prune()` or `eval_score()` raises, the process keeps the patched router flags and partially mutated model state. Wrap the validation section in a `try/finally` and restore the router flags / max subnet there.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/prune/plugins/mcore_minitron.py` around lines 477 - 513, The validation loop temporarily disables router flags and mutates the model but only restores them on the happy path; make it exception-safe by wrapping the candidate validation loop (the logic that iterates top_k_candidates and calls _prune, eval_score, sample, and saves/restores model.decoder.layers) in a try/finally so that in the finally you: (1) re-enable every module in _routers_with_expert_bias by setting m.enable_expert_bias = True, and (2) restore the model to the max subnet and original layer numbering (use sample(self.model, sample_func=max) and reset layer.layer_number from the saved start_layer_number and reassign self.model.decoder.layers) to ensure cleanup runs even if _prune() or eval_score() throws.
♻️ Duplicate comments (2)
modelopt/torch/prune/plugins/mcore_minitron.py (1)
277-292: ⚠️ Potential issue | 🟠 Major: Restore TE `return_layernorm_output` to the original value.

The registration path flips every fused TE module to `True`, but `cleanup()` always forces `False`. Any module that started as `True` will change behavior after pruning/search. Please store the original flag per module and restore that exact value; `run_search()` should also call `cleanup()` from a `finally` block so exceptions don't leak the patch.

Proposed fix

```diff
 class ImportanceEstimatorRegistry:
     def __init__(self, model: DynamicModule):
         """Initialize the registry."""
         assert isinstance(model, _DynamicMCoreLanguageModel), "Model must be a DynamicModule"
         self.model = model
         self._hooks: list[tuple[nn.Module, Any]] = []  # List of (module, hook_handle) tuples
+        self._te_ln_linear_prev_flags: dict[nn.Module, bool] = {}
 @@
     def cleanup(self) -> None:
         """Remove all registered hooks and temporary attributes."""
         # Remove all hooks
         for _, handle in self._hooks:
             handle.remove()
         self._hooks.clear()
-        # Unpatch return_layernorm_output on fused TELayerNormColumnParallelLinear modules
-        for m in self.model.modules():
-            if isinstance(m, TELayerNormColumnParallelLinear):
-                m.return_layernorm_output = False
+        for m, prev_flag in self._te_ln_linear_prev_flags.items():
+            m.return_layernorm_output = prev_flag
+        self._te_ln_linear_prev_flags.clear()
 @@
     for m in module.modules():
         if isinstance(m, TELayerNormColumnParallelLinear):
+            if m not in registry._te_ln_linear_prev_flags:
+                registry._te_ln_linear_prev_flags[m] = m.return_layernorm_output
             m.return_layernorm_output = True
```

Also applies to: 881-885, 1014-1016
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/prune/plugins/mcore_minitron.py` around lines 277 - 292, The cleanup flow currently forces every fused TE module's return_layernorm_output to False and the registration path sets them to True, which alters module behavior after pruning; update the code (in registration/patching logic that flips fused TE modules and in cleanup()) to record each module's original return_layernorm_output value (e.g., store a per-module map keyed by module identity) and restore that exact original boolean in cleanup() instead of hardcoding False, and ensure run_search() invokes cleanup() inside a finally block so the original flags are restored even if an exception occurs; refer to functions/methods cleanup(), run_search(), and the return_layernorm_output flag on the fused TE modules when making the changes.

tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (1)
74-75: ⚠️ Potential issue | 🟡 Minor: Guard these TE-only tests before model/plugin import.

These cases now force `transformer_impl="transformer_engine"`, but the module still only calls `skip_if_no_mamba()`. In environments with Mamba installed and Transformer Engine absent, this file will fail instead of skipping cleanly. Add `pytest.importorskip("megatron.core.extensions.transformer_engine")` or a shared helper next to `skip_if_no_mamba()`.

Also applies to: 155-156, 274-275
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py` around lines 74 - 75, The tests set transformer_impl="transformer_engine" but only call skip_if_no_mamba(), which lets cases run when Mamba exists but Transformer Engine does not; add a guard that imports or skips the TE module before importing or instantiating the model/plugin (e.g., call pytest.importorskip("megatron.core.extensions.transformer_engine") or create/use a shared helper that does that) and apply the same change for the other occurrences where transformer_impl="transformer_engine" is used; keep the existing skip_if_no_mamba() but ensure the TE import-or-skip runs earlier to prevent import failures.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Line 37: Wrap the module-level import of TELayerNormColumnParallelLinear in a
try/except that sets TELayerNormColumnParallelLinear = None and a HAS_TE boolean
flag (True on success, False on ImportError), then conditionally execute any
Transformer-Engine-specific logic only when HAS_TE is True; specifically guard
every place that references TELayerNormColumnParallelLinear (e.g., isinstance
checks and TE hook registrations) and the TE-specific hook registration blocks
so they run only if HAS_TE is True.
In `@modelopt/torch/utils/plugins/mbridge.py`:
- Around line 95-103: The else branch calls provider.num_moe_experts and
provider.qk_layernorm without defensive checks; update the branch around
transformer_layer_spec assignment (get_gpt_layer_with_transformer_engine_spec)
to either guard these attribute accesses with hasattr(provider,
"num_moe_experts") and hasattr(provider, "qk_layernorm") and handle missing
attributes (fallback values or raise a clear error), or add a concise
comment/docstring near the else branch documenting the assumption that provider
is a GPTModelProvider with those attributes (and cite the external megatron
source). Ensure references to provider.num_moe_experts, provider.qk_layernorm,
and get_gpt_layer_with_transformer_engine_spec are updated accordingly so future
callers see the validation or documented contract.
In `@tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py`:
- Around line 75-76: The test hardcodes transformer_impl="transformer_engine"
but only calls skip_if_no_mamba(), so when Mamba exists and TE does not the test
will fail; add a new helper skip_if_no_transformer_engine() to
tests/_test_utils/import_helper.py (modeled on the TE guards used in
test_megatron.py around line ~964) and call skip_if_no_transformer_engine() at
the module level in
tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py
alongside the existing skip_if_no_mamba() so the test is cleanly skipped when
Transformer Engine is unavailable.
---
Outside diff comments:
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Around line 477-513: The validation loop temporarily disables router flags and
mutates the model but only restores them on the happy path; make it
exception-safe by wrapping the candidate validation loop (the logic that
iterates top_k_candidates and calls _prune, eval_score, sample, and
saves/restores model.decoder.layers) in a try/finally so that in the finally
you: (1) re-enable every module in _routers_with_expert_bias by setting
m.enable_expert_bias = True, and (2) restore the model to the max subnet and
original layer numbering (use sample(self.model, sample_func=max) and reset
layer.layer_number from the saved start_layer_number and reassign
self.model.decoder.layers) to ensure cleanup runs even if _prune() or
eval_score() throws.
---
Duplicate comments:
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Around line 277-292: The cleanup flow currently forces every fused TE module's
return_layernorm_output to False and the registration path sets them to True,
which alters module behavior after pruning; update the code (in
registration/patching logic that flips fused TE modules and in cleanup()) to
record each module's original return_layernorm_output value (e.g., store a
per-module map keyed by module identity) and restore that exact original boolean
in cleanup() instead of hardcoding False, and ensure run_search() invokes
cleanup() inside a finally block so the original flags are restored even if an
exception occurs; refer to functions/methods cleanup(), run_search(), and the
return_layernorm_output flag on the fused TE modules when making the changes.
In `@tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py`:
- Around line 74-75: The tests set transformer_impl="transformer_engine" but
only call skip_if_no_mamba(), which lets cases run when Mamba exists but
Transformer Engine does not; add a guard that imports or skips the TE module
before importing or instantiating the model/plugin (e.g., call
pytest.importorskip("megatron.core.extensions.transformer_engine") or create/use
a shared helper that does that) and apply the same change for the other
occurrences where transformer_impl="transformer_engine" is used; keep the
existing skip_if_no_mamba() but ensure the TE import-or-skip runs earlier to
prevent import failures.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: bb96a786-c906-4909-a05e-e62d8eccbf69
📒 Files selected for processing (12)
- CHANGELOG.rst
- examples/megatron_bridge/README.md
- examples/megatron_bridge/prune_minitron.py
- modelopt/torch/nas/plugins/megatron.py
- modelopt/torch/nas/plugins/transformer_engine.py
- modelopt/torch/prune/plugins/mcore_minitron.py
- modelopt/torch/utils/plugins/mbridge.py
- tests/_test_utils/torch/megatron/models.py
- tests/gpu_megatron/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py
- tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py
- tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
- tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
💤 Files with no reviewable changes (1)
- modelopt/torch/nas/plugins/transformer_engine.py
✅ Files skipped from review due to trivial changes (2)
- modelopt/torch/nas/plugins/megatron.py
- tests/_test_utils/torch/megatron/models.py
🚧 Files skipped from review as they are similar to previous changes (5)
- examples/megatron_bridge/prune_minitron.py
- tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
- tests/gpu_megatron/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py
- examples/megatron_bridge/README.md
- CHANGELOG.rst
tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/prune/plugins/mcore_minitron.py (1)
476-513: ⚠️ Potential issue | 🟠 Major: Restore `enable_expert_bias` in a `finally` block.

If `eval_score()` or checkpoint saving raises inside this loop, the model is left with expert bias permanently disabled for the rest of the search.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/prune/plugins/mcore_minitron.py` around lines 476 - 513, The loop temporarily disables module expert bias by setting m.enable_expert_bias = False and collects them in _routers_with_expert_bias, but if eval_score() or save_search_checkpoint() throws the routers remain disabled; wrap the validation phase (the for candidate in tqdm(...) loop and its inner pruning/eval/save logic) in a try/finally (or surround the whole top-k validation block) and in the finally re-enable each router by setting m.enable_expert_bias = True (using the existing _routers_with_expert_bias list) so expert bias is always restored even on exceptions.
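The `try/finally` shape being requested looks roughly like this; `FakeRouter`, `validate_candidates`, and `eval_score` are illustrative stand-ins for the searcher's router modules and scoring call, not the actual ModelOpt code:

```python
class FakeRouter:
    """Stand-in for an MoE router module with an expert-bias flag."""
    def __init__(self):
        self.enable_expert_bias = True

def validate_candidates(routers, candidates, eval_score):
    """Disable expert bias during candidate scoring and guarantee it is
    re-enabled even if scoring raises mid-loop."""
    patched = [r for r in routers if r.enable_expert_bias]
    for r in patched:
        r.enable_expert_bias = False
    scores = {}
    try:
        for cand in candidates:
            scores[cand] = eval_score(cand)
    finally:
        # Runs on both the happy path and on exceptions, so the model is
        # never left with expert bias permanently disabled.
        for r in patched:
            r.enable_expert_bias = True
    return scores
```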
♻️ Duplicate comments (2)
modelopt/torch/prune/plugins/mcore_minitron.py (2)
37-37: ⚠️ Potential issue | 🟠 Major: Gate the TE-specific import like the rest of the optional stack.

This makes the pruning plugin fail import whenever Transformer Engine is absent, even if the caller never exercises the TE path. Please wrap the import in `try/except ImportError` and guard the TE-specific hook logic with a feature flag.

As per coding guidelines, `modelopt/**/*.py`: Avoid hard imports of optional dependencies at module level; gate features by install extras (`[onnx]`, `[hf]`, `[all]`).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/prune/plugins/mcore_minitron.py` at line 37, The module currently hard-imports TELayerNormColumnParallelLinear which causes import failure when Transformer Engine (TE) is absent; wrap the import in a try/except ImportError and set a module-level feature flag (e.g., _HAS_TE = True/False) accordingly, then guard all TE-specific logic/hook usage (references to TELayerNormColumnParallelLinear and any TE-only hooks in the plugin class in mcore_minitron.py) behind that flag so importing the module succeeds even when TE is not installed.
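The import-gating pattern referenced here (mirroring the `HAS_MAMBA` flag elsewhere in the codebase) is a standard try/except guard; `register_te_hooks` below is a hypothetical consumer showing how the flag gates TE-only logic:

```python
# Gate the optional dependency at module level: importing this module must
# succeed even when Transformer Engine / Megatron-Core is not installed.
try:
    from megatron.core.extensions.transformer_engine import (  # type: ignore
        TELayerNormColumnParallelLinear,
    )
    HAS_TE = True
except ImportError:
    TELayerNormColumnParallelLinear = None
    HAS_TE = False

def register_te_hooks(model_modules):
    """Only touch TE modules when TE is importable; otherwise a no-op."""
    if not HAS_TE:
        return []
    return [m for m in model_modules if isinstance(m, TELayerNormColumnParallelLinear)]
```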
881-885: ⚠️ Potential issue | 🟠 Major: Restore each module's original `return_layernorm_output` value.

This still patches every fused TE module to `True` and later forces `False` globally. Any module that started with `True` is silently mutated after cleanup, and the blanket patch is broader than the modules you actually hook.

Also applies to: 1014-1016
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/prune/plugins/mcore_minitron.py` around lines 881 - 885, The cleanup loop currently forces return_layernorm_output=False on every TELayerNormColumnParallelLinear in self.model, mutating modules that may have originally been True; instead track which specific modules you patched (e.g., store them in a list when you set return_layernorm_output=True where you install hooks) and in the cleanup only restore each tracked module's original value (store the original value per-module when patching). Update the code that sets return_layernorm_output to record (module, original_value) and replace the current iteration over self.model.modules() with iteration over that recorded list to restore original_value on the same TELayerNormColumnParallelLinear instances you modified.
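The save-and-restore pattern asked for in this thread can be sketched with plain objects; `FakeFusedLinear` and `FlagPatcher` are hypothetical stand-ins for `TELayerNormColumnParallelLinear` and the hook registry, not ModelOpt code:

```python
class FakeFusedLinear:
    """Stand-in for a fused layernorm+linear module with the TE flag."""
    def __init__(self, return_layernorm_output: bool):
        self.return_layernorm_output = return_layernorm_output

class FlagPatcher:
    """Record each module's original flag when patching, and restore that
    exact value on cleanup, never a blanket False."""
    def __init__(self):
        # Keyed by id() so unhashable-by-value modules still work.
        self._prev_flags: dict[int, tuple[object, bool]] = {}

    def patch(self, module):
        if id(module) not in self._prev_flags:
            self._prev_flags[id(module)] = (module, module.return_layernorm_output)
        module.return_layernorm_output = True

    def cleanup(self):
        for module, prev in self._prev_flags.values():
            module.return_layernorm_output = prev
        self._prev_flags.clear()
```

Only the modules actually passed to `patch()` are touched, so cleanup restores `True` for modules that started as `True` and `False` for the rest.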
🧹 Nitpick comments (1)
modelopt/torch/prune/plugins/mcore_minitron.py (1)
186-190: The checkpoint payload types no longer match the data you store.

`local_activations` is not uniformly `Tensor` here: the hidden-size collector stores a nested dict keyed by module id, the SequentialMLP collector stores a dict of tensors, and `layer_scores` are Python floats (`.item()`), not tensors. Tighten these aliases and signatures so mypy and checkpoint consumers see the real payload shape.

As per coding guidelines, `**/*.py`: Use mypy for type checking on Python code (configured in `pyproject.toml`).

Also applies to: 208-214, 914-946
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/prune/plugins/mcore_minitron.py` around lines 186 - 190, The checkpoint payload type hints are incorrect: adjust the annotations for local_activations and layer_scores (and any related aliases near sorted_layers and all_candidates_per_constraint) to match what is actually stored — make local_activations a union type reflecting both nested dicts (dict[str, dict[str, torch.Tensor]]) and flat dicts (dict[str, torch.Tensor]) produced by the hidden-size collector and SequentialMLP collector, and change layer_scores from dict[int, torch.Tensor] to dict[int, float] (since .item() is stored); ensure sorted_layers remains Optional[List[int]] (1-indexed) and all_candidates_per_constraint stays dict[float, list[CandidateSubnet]]; update any type aliases or function signatures that construct or consume these payloads so mypy and checkpoint serializers see the real shapes (search for local_activations, layer_scores, sorted_layers, all_candidates_per_constraint, CandidateSubnet to locate all usages).
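The tightened aliases could look like the sketch below; `SearchCheckpoint` and the `Tensor = list` stand-in (the real payload holds `torch.Tensor` values) are assumptions used only to make the shapes concrete:

```python
from typing import TypedDict, Union

Tensor = list  # stand-in for torch.Tensor so the sketch stays runnable

# Flat dict (SequentialMLP collector) or nested dict keyed by module id
# (hidden-size collector), matching what the review says is actually stored.
LocalActivations = Union[dict[str, Tensor], dict[str, dict[str, Tensor]]]

class SearchCheckpoint(TypedDict):
    local_activations: LocalActivations
    layer_scores: dict[int, float]  # `.item()` results, not tensors
    sorted_layers: list[int]

ckpt: SearchCheckpoint = {
    "local_activations": {"decoder.layers.0.mlp": [0.1, 0.2]},
    "layer_scores": {1: 0.5},
    "sorted_layers": [1],
}
```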
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/nas/plugins/megatron.py`:
- Around line 390-403: The specialized TE parallel linear class registers only
"out_features" (via self._register_dynamic_attribute("out_features", ...)) but
not "in_features", causing stale dimension metadata after slicing; update the
class to also call self._register_dynamic_attribute("in_features", <callable>)
mirroring the parent _DynamicTEParallelLinear behavior (compute input dim from
mod.config.kv_channels and any active/group counts as appropriate), and ensure
"in_features" and "out_features" are kept in sync with the sliced weight/bias
(use the same pattern/closures used for "out_features" and for
bias/_get_weight/_get_bias/_get_ln_param to compute current dimensions).
- Around line 25-31: The module currently imports transformer_engine and the
Megatron TE extension classes (transformer_engine, TEColumnParallelLinear,
TEDotProductAttention, TELayerNormColumnParallelLinear, TERowParallelLinear)
unconditionally which breaks imports when TE is absent; wrap these imports in a
try/except and set a HAS_TE boolean (mirroring the HAS_MAMBA pattern already in
this file) so downstream registration of TE-specific features only occurs when
HAS_TE is True, and ensure any code that references those TE symbols is guarded
by the HAS_TE check similar to the pattern used in
modelopt/torch/quantization/plugins/megatron.py.
---
Outside diff comments:
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Around line 476-513: The loop temporarily disables module expert bias by
setting m.enable_expert_bias = False and collects them in
_routers_with_expert_bias, but if eval_score() or save_search_checkpoint()
throws the routers remain disabled; wrap the validation phase (the for candidate
in tqdm(...) loop and its inner pruning/eval/save logic) in a try/finally (or
surround the whole top-k validation block) and in the finally re-enable each
router by setting m.enable_expert_bias = True (using the existing
_routers_with_expert_bias list) so expert bias is always restored even on
exceptions.
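A minimal sketch of that try/finally shape, with `_Router` as a stand-in for the real MoE router module and `eval_score` as a placeholder callable that may raise:

```python
# Sketch: expert bias must be restored even if candidate evaluation raises.
class _Router:
    def __init__(self):
        self.enable_expert_bias = True


def validate_candidates(routers, candidates, eval_score):
    _routers_with_expert_bias = []
    for m in routers:
        if getattr(m, "enable_expert_bias", False):
            m.enable_expert_bias = False
            _routers_with_expert_bias.append(m)
    try:
        for candidate in candidates:
            eval_score(candidate)  # may raise; bias must still be restored
    finally:
        # Restore on every exit path, not just on success.
        for m in _routers_with_expert_bias:
            m.enable_expert_bias = True


routers = [_Router(), _Router()]


def failing_eval(candidate):
    raise RuntimeError("eval failed")


try:
    validate_candidates(routers, ["subnet-a"], failing_eval)
except RuntimeError:
    pass

print(all(r.enable_expert_bias for r in routers))  # → True
```

The same `finally` block also covers a failing `save_search_checkpoint()`, since anything inside the `try` is protected.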
---
Duplicate comments:
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Line 37: The module currently hard-imports TELayerNormColumnParallelLinear
which causes import failure when Transformer Engine (TE) is absent; wrap the
import in a try/except ImportError and set a module-level feature flag (e.g.,
_HAS_TE = True/False) accordingly, then guard all TE-specific logic/hook usage
(references to TELayerNormColumnParallelLinear and any TE-only hooks in the
plugin class in mcore_minitron.py) behind that flag so importing the module
succeeds even when TE is not installed.
- Around line 881-885: The cleanup loop currently forces
return_layernorm_output=False on every TELayerNormColumnParallelLinear in
self.model, mutating modules that may have originally been True; instead track
which specific modules you patched (e.g., store them in a list when you set
return_layernorm_output=True where you install hooks) and in the cleanup only
restore each tracked module's original value (store the original value
per-module when patching). Update the code that sets return_layernorm_output to
record (module, original_value) and replace the current iteration over
self.model.modules() with iteration over that recorded list to restore
original_value on the same TELayerNormColumnParallelLinear instances you
modified.
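The patch-and-restore bookkeeping can be sketched like this; `_TELNLinear` is a stand-in for `TELayerNormColumnParallelLinear`, and the list of `(module, original_value)` pairs replaces the blanket iteration over `self.model.modules()`:

```python
# Sketch: record each module's original return_layernorm_output when
# installing hooks, then restore only those modules on cleanup.
class _TELNLinear:
    def __init__(self, return_layernorm_output):
        self.return_layernorm_output = return_layernorm_output


modules = [_TELNLinear(False), _TELNLinear(True), _TELNLinear(False)]

# Patch phase: store (module, original_value) pairs instead of mutating
# blindly, so modules that were already True are remembered as True.
patched = []
for m in modules:
    patched.append((m, m.return_layernorm_output))
    m.return_layernorm_output = True  # hooks need the fused LN output

# ... run activation collection with hooks installed ...

# Cleanup phase: restore each patched module to its ORIGINAL value
# rather than forcing False on every instance in the model.
for m, original_value in patched:
    m.return_layernorm_output = original_value

print([m.return_layernorm_output for m in modules])  # → [False, True, False]
```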
---
Nitpick comments:
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Around line 186-190: The checkpoint payload type hints are incorrect: adjust
the annotations for local_activations and layer_scores (and any related aliases
near sorted_layers and all_candidates_per_constraint) to match what is actually
stored — make local_activations a union type reflecting both nested dicts
(dict[str, dict[str, torch.Tensor]]) and flat dicts (dict[str, torch.Tensor])
produced by the hidden-size collector and SequentialMLP collector, and change
layer_scores from dict[int, torch.Tensor] to dict[int, float] (since .item() is
stored); ensure sorted_layers remains Optional[List[int]] (1-indexed) and
all_candidates_per_constraint stays dict[float, list[CandidateSubnet]]; update
any type aliases or function signatures that construct or consume these payloads
so mypy and checkpoint serializers see the real shapes (search for
local_activations, layer_scores, sorted_layers, all_candidates_per_constraint,
CandidateSubnet to locate all usages).
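One possible shape for the corrected annotations, assuming illustrative alias names; `"torch.Tensor"` is written as a string forward reference so the snippet does not require torch at import time.

```python
from typing import Optional, Union

# Corrected checkpoint-payload annotations (alias names are illustrative).
LocalActivations = Union[
    dict[str, dict[str, "torch.Tensor"]],  # nested: hidden-size collector
    dict[str, "torch.Tensor"],             # flat: SequentialMLP collector
]
LayerScores = dict[int, float]        # .item() is stored, so float, not Tensor
SortedLayers = Optional[list[int]]    # 1-indexed layer numbers
```

Functions that construct or consume the checkpoint payload would then annotate against these aliases so mypy sees the real shapes.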
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 8d87f835-546c-4fe1-995b-e21625d0dee5
📒 Files selected for processing (6)
- CHANGELOG.rst
- examples/megatron_bridge/prune_minitron.py
- modelopt/torch/nas/plugins/megatron.py
- modelopt/torch/prune/plugins/mcore_minitron.py
- tests/_test_utils/torch/megatron/models.py
- tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
✅ Files skipped from review due to trivial changes (1)
- tests/_test_utils/torch/megatron/models.py
🚧 Files skipped from review as they are similar to previous changes (3)
- examples/megatron_bridge/prune_minitron.py
- tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
- CHANGELOG.rst
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Force-pushed from 067e80a to 74184b8
```diff
  # MoE-specific parameters
- moe_router_dtype=None,
  moe_grouped_gemm=moe_grouped_gemm,
+ moe_router_dtype="fp32",
```
MoE router dtype FP32 is needed in Nemotron models.
I was seeing some issues with the pruning test model, so I didn't set it here. For Nemotron3 Nano pruning on the actual model, I used its default and it works fine.
What does this PR do?
Type of change: Improvement + Bug Fix
Quantization recently added support for the Full TE spec. This PR adds the same for pruning so we can retire the ModelOpt spec and just use the standard TE spec.

NOTE: We still don't support TEGroupedGemm and use TE SequentialMLP instead for now (this can be configured in the standard TE spec, so we don't need the ModelOpt spec).

Note that this does not change how the pruning workflow is used, but it makes pruning slightly faster and may produce a slightly different pruned model because of different kernels and numerics.

[Bug fix]: Previously, NAS-based pruning for MoE models would hang when evaluating MMLU on pruned candidate models because of a bug. Fixed in this PR.

[Bug fix]: Previously, hidden-size importance hooks were not applied to pre_mlp_layernorm in MoE layers. Fixed in this PR as well, resulting in a significant MMLU improvement for Qwen3-30B-A3B.
Testing
Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (
git commit -s -S).Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded
trust_remote_code=True,torch.load(..., weights_only=False),pickle, etc.).CONTRIBUTING.md: N/AAdditional Information
OMNIML-3504
Tests